#! /usr/bin/env python3

import argparse
import logging
import os
import numpy as np

import misc.utils as utils

import urlearning.ScoreCache as ScoreCache
import urlearning.TarjansAlgorithm as TarjansAlgorithm
import urlearning.TopPPopsConstraint as TopPPopsConstraint

# Defaults for the optional command-line arguments defined in main().
# A time limit below 1 means the astar calls are made without a time limit.
default_time = 0
# An empty path means no pattern database file is given to astar; the
# "-e static" option is used instead.
default_pd_file = ""
# An empty path means astar is not asked to write network files and the
# final merge-networks step is skipped.
default_net = ""

def get_indices_string(variables):
    """Return the indices of the nonzero entries of ``variables`` as text.

    The indices (along the first axis, as reported by ``np.nonzero``) are
    joined with commas and no spaces, e.g. ``[0, 1, 0, 1]`` gives ``"1,3"``.
    An input without nonzero entries yields the empty string.
    """
    nonzero_positions = np.nonzero(variables)[0]
    pieces = [str(int(position)) for position in nonzero_positions]
    return ",".join(pieces)

def _parse_astar_output(output):
    """Pull the scores, expansion counts and timing line out of astar output.

    Every "Found solution:" and "Nodes expanded:" line is logged at INFO
    level as soon as it is encountered, in the order the lines appear.

    Returns a tuple ``(scores, expansions, time_line)``:
      * ``scores``: floats parsed from each "Found solution:" line,
      * ``expansions``: ints parsed from each "Nodes expanded:" line,
      * ``time_line``: the last line containing "wall", stripped, or ""
        when there is no such line.
    """
    scores = []
    expansions = []
    time_line = ""

    for line in output.split("\n"):
        if line.startswith("Found solution:"):
            score = float(line.partition(": ")[2])
            scores.append(score)
            logging.info("Score: {}".format(score))

        elif line.startswith("Nodes expanded:"):
            # The count is the third space-separated token; its final
            # character is dropped (presumably trailing punctuation in the
            # astar output -- confirm against the astar program).
            expanded = int(line.split(" ")[2][:-1])
            expansions.append(expanded)
            logging.info("Expanded: {}".format(expanded))

        elif "wall" in line:
            time_line = line.strip()

    return scores, expansions, time_line


def _parse_timing_line(time_line):
    """Return ``(wall, user, system)`` seconds from a timing line, or None.

    The expected format is
    "3.319937s wall, 3.170000s user + 0.140000s system = 3.310000s CPU (99.7%)".
    Lines that do not split into exactly 11 space-separated fields are
    treated as "no timing information" and yield None.
    """
    fields = time_line.split(" ")
    if len(fields) != 11:
        return None

    # Each value carries a trailing "s" which is stripped before conversion.
    return float(fields[0][:-1]), float(fields[2][:-1]), float(fields[5][:-1])


def main():
    """Run astar on each strongly connected component of the parent graph.

    The score cache is read from the pss file, a constrained parent relation
    graph is built from it (with ``p=0``) and its strongly connected
    components (SCCs) are extracted. astar is then called once per SCC, in
    the order the SCCs are returned, passing the variables of all previously
    solved SCCs via ``-p``. Scores, expanded node counts and timings are
    parsed from the astar output and summed. Processing stops at the first
    astar call that reports no solution.

    Timing totals are always printed; score and expansion totals are printed
    only when the last astar call found a solution. If ``--net`` is given,
    the per-SCC network files are merged with merge-networks.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("pss", help="pss file used to find the structure (input)")
    parser.add_argument("-t", "--time", help="Time limit for the calls to astar. If t < 1, "
        "then no limit is given.", type=int, default=default_time)
    parser.add_argument("-p", "--pd-file", help="Precalculated pattern database file.",
        default=default_pd_file)
    parser.add_argument("-n", "--net", help="The output file containing the learned "
        "network structure.", default=default_net)

    utils.add_logging_options(parser)
    args = parser.parse_args()
    utils.update_logging(args)

    programs = ['astar', 'merge-networks']
    utils.check_programs_exist(programs)

    logging.info("Reading the score cache")
    sc = ScoreCache.ScoreCache(args.pss)

    logging.info("Creating the parent relation graph")
    tppc = TopPPopsConstraint.TopPPopsConstraint()
    parent_relation_graph = tppc.createConstrainedGraph(sc, p=0)

    logging.info("Extracting the ancestor constraints")
    tarjans = TarjansAlgorithm.TarjansAlgorithm()
    sccs = tarjans.getSCCs(parent_relation_graph)

    # Indicator vector of the variables in all SCCs solved so far.
    ancestors = np.zeros(len(sc))

    total_score = 0
    total_expanded = 0
    total_wall = 0
    total_user = 0
    total_system = 0

    # These astar options are the same for every SCC.
    time_argument = "-r {}".format(args.time) if args.time > 0 else ""
    if args.pd_file:
        pd_argument = "-e file -a {}".format(args.pd_file)
    else:
        pd_argument = "-e static"

    # Initialised before the loop so that the check after the loop is well
    # defined even when there are no SCCs to process.
    found_solution = False

    for i, scc in enumerate(sccs):
        scc_string = get_indices_string(scc)
        scc_argument = "-s {}".format(scc_string)

        logging.debug("Considering scc: '{}'".format(scc_string))

        ancestors_string = get_indices_string(ancestors)
        ancestors_argument = ""
        if ancestors_string:
            ancestors_argument = "-p {}".format(ancestors_string)

        net_argument = ""
        if args.net:
            net_argument = "-n {}.scc{}".format(args.net, i)

        # NOTE(review): the command is built as one unquoted string, so file
        # names containing spaces or shell metacharacters are not handled.
        cmd = "astar {} {} {} {} {} {}".format(args.pss, ancestors_argument, scc_argument,
            pd_argument, time_argument, net_argument)

        res = utils.check_output(cmd)
        logging.debug(res)

        scores, expansions, last_time_line = _parse_astar_output(res)
        found_solution = len(scores) > 0
        for score in scores:
            total_score += score
        for expanded in expansions:
            total_expanded += expanded

        # Timings are accumulated even for a failed call.
        times = _parse_timing_line(last_time_line)
        if times is not None:
            wall, user, system = times

            logging.info("Wall: {}".format(wall))
            logging.info("User: {}".format(user))
            logging.info("System: {}".format(system))

            total_wall += wall
            total_user += user
            total_system += system

        if not found_solution:
            logging.error("The astar call failed. Quitting.")
            break

        # The variables of this SCC become ancestors for the remaining SCCs.
        ancestors = np.maximum(ancestors, scc)

    print("Timing for astar calls")
    print("Total wall time: {}s".format(total_wall))
    print("Total user time: {}s".format(total_user))
    print("Total system time: {}s".format(total_system))

    if found_solution:
        print("Total score: {}".format(total_score))
        print("Total expanded: {}".format(total_expanded))

    if args.net:
        # NOTE(review): the merge is attempted even if an astar call failed,
        # and the "scc*" pattern matches any file with that prefix, including
        # ones left over from earlier runs -- confirm this is intended.
        cmd = "merge-networks {} {}".format("{}.scc*".format(args.net), args.net)
        utils.check_call(cmd)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
